"""
    Script to evaluate score models (conditional with a linear reward model and unconditional)
"""

# pylint: disable=no-name-in-module

import copy
import json
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import torch
from omegaconf import DictConfig

from diffusion_bandit import utils
from diffusion_bandit.conditional_generation import ScoreRewardModel
from diffusion_bandit.dataset_generation import distance_to_sphere
from diffusion_bandit.diffusion import DiffusionProcess
from diffusion_bandit.neural_networks.shape_reward_nets import (
    get_ground_truth_reward_model,
)
from diffusion_bandit.samplers import get_sampler

# pylint: disable=missing-function-docstring


@hydra.main(
    version_base=None, config_path="configs", config_name="evaluate_shape_score_model"
)
def main(config: DictConfig) -> None:
    """Evaluate a saved score model with unconditional and reward-conditioned sampling.

    Loads ``{outputs_dir}/{names.score_model}.pth``, draws a random weight vector
    for the reward model's ``layer``, samples unconditionally and then conditioned
    on ``num_conditions`` evenly spaced target rewards spanning the rewards of the
    stored data. Distances of the sampled paths to the manifold and the obtained
    rewards are plotted and written as JSON into ``config.outputs_dir``.

    Args:
        config: Hydra configuration (``evaluate_shape_score_model``).
    """
    utils.seeding.seed_everything(config)

    # Number of evenly spaced target rewards used for conditional generation.
    num_conditions = 8
    batch_size = config.sampler.batch_size
    num_steps = config.sampler.num_steps

    # Setup
    device = torch.device(config.sampler.device)
    model_load_path = Path(config.outputs_dir) / f"{config.names.score_model}.pth"
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    # map_location lets a checkpoint saved on another device be loaded here.
    saved_dict = torch.load(model_load_path, weights_only=False, map_location=device)

    # Extract parameters from saved_dict
    dataset_config = saved_dict["dataset_config"]["dataset"]
    d_ext, radius, surface = [
        dataset_config[key] for key in ["d_ext", "radius", "surface"]
    ]
    projector, x_data, score_model, beta_min, beta_max = [
        saved_dict[key]
        for key in ["projector", "x_data", "score_model", "beta_min", "beta_max"]
    ]

    # Set up distance function
    manifold_distance_fn = distance_to_sphere

    diffusion_process = DiffusionProcess(
        beta_min=beta_min,
        beta_max=beta_max,
    )

    # Get sampler
    sampler = get_sampler(
        name=config.names.sampler,
        shape=[d_ext],
        diffusion_process=diffusion_process,
        **config.sampler,
    )

    # Set up reward model
    reward_model = get_ground_truth_reward_model(
        d_ext=d_ext,
        projector=projector,
        radius=radius,
        surface=surface,
        name=config.reward_model.name,
        diffusion_process=diffusion_process,
        score=copy.deepcopy(score_model),
    )

    # Move both models to the target device *before* any forward pass so that
    # parameters and inputs always live on the same device.
    for model in [score_model, reward_model]:
        model.to(device)

    # Random reward direction written into the reward model's linear layer.
    theta = torch.randn(d_ext, device=device)
    reward_model.layer.weight.data = theta.reshape(1, -1)

    x_data = torch.as_tensor(x_data, dtype=torch.float32, device=device)
    # Only the range of data rewards is needed below; keep a detached CPU copy.
    rewards = reward_model(x_data).detach().cpu()

    # Model setup
    reward_model.train()
    score_model.eval()
    for param in reward_model.parameters():
        param.requires_grad = True

    # Initialize eval dictionary
    results = {}
    model_name = config.names.score_model.split(".")[0]
    filename_suffix = (
        f"steps{num_steps}_"
        f"alpha{config.conditional.alpha}_"
        f"var{config.conditional.variance}"
    )

    # Unconditional generation
    paths = sampler.sample(
        score_model=score_model,
        batch_size=batch_size,
        num_steps=num_steps,
    )

    # One row per element of `paths`; statistics are taken over dim 1 below.
    distances = (
        torch.stack(
            [manifold_distance_fn(path, projector, radius, surface) for path in paths]
        )
        .detach()
        .cpu()
    )

    # Store data in eval dictionary
    results["unconditional"] = {
        "median_distances": distances.median(dim=1).values.tolist(),
        "q1_distances": distances.quantile(0.25, dim=1).tolist(),
        "q3_distances": distances.quantile(0.75, dim=1).tolist(),
    }

    plot_distances_to_manifold(distances)
    plt.savefig(
        Path(config.outputs_dir)
        / f"{config.names.save}{model_name}_unconditional_{filename_suffix}.png"
    )
    plt.close()

    # Conditional generation: targets evenly spaced over the data reward range.
    iterates = torch.linspace(
        rewards.min().item(), rewards.max().item(), steps=num_conditions
    )
    # Condition layout: [target_0 x batch_size, target_1 x batch_size, ...]
    reward_values = (
        iterates.repeat_interleave(batch_size).unsqueeze(-1).to(device)
    )

    reward_model.set_mode("sampling")

    batched_score_reward_model = ScoreRewardModel(
        condition=reward_values,
        reward_model=reward_model,
        score_model=score_model,
        variance=config.conditional.variance,
        alpha=config.conditional.alpha,
    )

    paths = sampler.sample(
        score_model=batched_score_reward_model,
        batch_size=batch_size * num_conditions,
        num_steps=num_steps,
    )

    reward_model.set_mode("feedback")

    # (num_conditions, batch_size), matching the repeat_interleave layout above.
    obtained_rewards = (
        reward_model(paths[-1])
        .detach()
        .cpu()
        .reshape(num_conditions, batch_size)
    )

    # Each step holds num_conditions * batch_size distances with the condition
    # index varying slowest. Reshape to (steps, num_conditions, batch_size), then
    # swap the last two axes to get (steps, batch_size, num_conditions), which is
    # what plot_conditional_generation indexes. A direct reshape to
    # (-1, batch_size, num_conditions) would mix samples from different targets.
    distances_to_manifold = (
        torch.stack(
            [manifold_distance_fn(path, projector, radius, surface) for path in paths]
        )
        .detach()
        .cpu()
        .reshape(-1, num_conditions, batch_size)
        .transpose(1, 2)
    )

    # Store data in eval dictionary
    results["conditional"] = {
        "iterates": iterates.tolist(),
        "obtained_rewards": obtained_rewards.tolist(),
        "distances_to_manifold": distances_to_manifold.tolist(),
    }

    plot_conditional_generation(
        iterates, obtained_rewards, distances_to_manifold, config
    )

    eval_filename = (
        f"{config.names.save}{model_name}_conditional_{filename_suffix}.json"
    )
    with open(Path(config.outputs_dir) / eval_filename, "w", encoding="utf-8") as file:
        json.dump(results, file)


def plot_reward_histogram(rewards, radii, config):
    """Save a 1D reward histogram next to a 2D reward-vs-radius histogram.

    Args:
        rewards: Reward per sample. Tensors are detached, moved to CPU and
            flattened to 1D; other inputs are handed to matplotlib as-is
            (assumed to already be 1D arrays — confirm at call site).
        radii: Radius per sample, same length and handling as ``rewards``.
        config: Hydra config; ``names``, ``sampler.num_steps``,
            ``conditional`` and ``outputs_dir`` build the output path.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))

    # reshape(-1) instead of squeeze(): squeeze turns a single sample into a
    # 0-d array, which the histogram calls cannot handle.
    if hasattr(rewards, "cpu"):
        rewards = rewards.detach().cpu().numpy().reshape(-1)
    if hasattr(radii, "cpu"):
        radii = radii.detach().cpu().numpy().reshape(-1)

    # First subplot: 1D histogram
    ax1.hist(rewards, bins=50, density=True, edgecolor="black")
    ax1.set_xlabel("Reward Value")
    ax1.set_ylabel("Density")
    ax1.grid(True, alpha=0.3)

    # Add mean line
    mean_reward = rewards.mean()
    ax1.axvline(
        mean_reward,
        color="r",
        linestyle="dashed",
        linewidth=2,
        label=f"Mean: {mean_reward:.2f}",
    )
    ax1.legend()

    # Second subplot: 2D histogram
    hist2d = ax2.hist2d(rewards, radii, bins=50, cmap="Blues")
    ax2.set_xlabel("Reward Value")
    ax2.set_ylabel("Radius")
    # hist2d returns (counts, xedges, yedges, image); the image feeds the colorbar.
    fig.colorbar(hist2d[3], ax=ax2, label="Count")

    fig.tight_layout()

    model_name = config.names.score_model.split(".")[0]
    # Same "<name>_<kind>_steps..._alpha..._var..." pattern as the other outputs.
    plot_filename = (
        f"{config.names.save}{model_name}_reward_histogram_"
        f"steps{config.sampler.num_steps}_"
        f"alpha{config.conditional.alpha}_"
        f"var{config.conditional.variance}.png"
    )
    fig.savefig(Path(config.outputs_dir) / plot_filename)
    plt.close(fig)

    print(f"Reward histogram saved to {plot_filename}")


def plot_distances_to_manifold(distances, lower_quantile=0.05, upper_quantile=0.95):
    """Plot the median distance to the manifold per row with a quantile band.

    Creates a new pyplot figure and leaves it open; the caller is responsible
    for saving and closing it.

    Args:
        distances: 2D tensor; each row is one x-axis position and the median
            and quantiles are taken over dim 1.
        lower_quantile: Lower edge of the shaded band (default 5%).
        upper_quantile: Upper edge of the shaded band (default 95%).
    """
    # matplotlib needs CPU tensors without autograd history.
    distances = distances.detach().cpu()

    median_distances = distances.median(dim=1).values
    lower_distances = distances.quantile(lower_quantile, dim=1)
    upper_distances = distances.quantile(upper_quantile, dim=1)
    x_axis = torch.arange(len(median_distances))

    # The band spans the chosen quantiles (5-95% by default), not the IQR.
    band_label = f"{lower_quantile * 100:g}-{upper_quantile * 100:g}% quantiles"

    plt.figure(figsize=(10, 6))
    plt.plot(x_axis, median_distances, label="Median Distance")
    plt.plot(x_axis, torch.zeros_like(x_axis), label="Zero Line")
    plt.fill_between(
        x_axis, lower_distances, upper_distances, alpha=0.5, label=band_label
    )
    plt.ylabel(r"$\| x - \text{Proj}_{\Omega}[x]\|_2$")
    plt.xlabel("Index")
    plt.title("Distance to Manifold")
    plt.legend()
    plt.grid(True, alpha=0.3)


def plot_conditional_generation(iterates, rewards, distances_to_manifold, config):
    """Plot and save diagnostics for reward-conditioned generation.

    Three stacked panels:
      1. median distance to the manifold along the sampling trajectory, one
         curve per conditioned reward, with a 5-95% quantile band (log scale);
      2. obtained vs. conditioned reward, both divided by a common positive
         scale, with a y=x reference line;
      3. median final distance to the manifold per conditioned reward.

    Args:
        iterates: 1D tensor of conditioned reward targets.
        rewards: Obtained rewards, shape (len(iterates), batch).
        distances_to_manifold: Distances, shape (steps, batch, len(iterates)).
        config: Hydra config; ``names``, ``sampler.num_steps``,
            ``conditional`` and ``outputs_dir`` build the output path.
    """
    # matplotlib needs CPU tensors without autograd history.
    iterates = iterates.detach().cpu()
    rewards = rewards.detach().cpu()
    distances_to_manifold = distances_to_manifold.detach().cpu()

    _, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 9))

    # Plot 1: Distance to manifold over time for each condition
    for i, iterate in enumerate(iterates):
        per_condition = distances_to_manifold[:, :, i]
        median_distance = per_condition.median(dim=1).values
        lower_distance = per_condition.quantile(0.05, dim=1)
        upper_distance = per_condition.quantile(0.95, dim=1)
        x_axis = torch.arange(len(median_distance))
        ax1.plot(x_axis, median_distance, label=f"Reward {iterate:.2f}")
        ax1.fill_between(x_axis, lower_distance, upper_distance, alpha=0.2)

    ax1.set_ylabel(r"$\| x - \text{Proj}_{\Omega}[x]\|_2$")
    ax1.set_yscale("log")
    ax1.set_xlabel(r"diffusion steps")
    ax1.set_title("Distance to Manifold Over Time")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Conditioned vs Obtained Reward, normalized by the largest target.
    # A non-positive max would flip the axes (making yerr negative, which
    # matplotlib rejects) or divide by zero, so fall back to the largest
    # absolute target, or 1 if every target is zero.
    scale = iterates.max().item()
    if scale <= 0:
        scale = iterates.abs().max().item() or 1.0
    normalized_iterates = iterates / scale
    median_rewards = rewards.median(dim=1).values / scale
    lower_rewards = rewards.quantile(0.05, dim=1) / scale
    upper_rewards = rewards.quantile(0.95, dim=1) / scale
    ax2.errorbar(
        normalized_iterates,
        median_rewards,
        yerr=[median_rewards - lower_rewards, upper_rewards - median_rewards],
        fmt="o",
        capsize=5,
        label="Median standard rewards with 90CI",
        color="blue",
        alpha=0.7,
    )
    ax2.plot(
        normalized_iterates,
        normalized_iterates,
        "r--",
        alpha=0.4,
        label="y=x line",
    )
    ax2.set_xlabel("Conditioned Reward (normalized)")
    ax2.set_ylabel("Obtained Reward")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot 3: Final Distance to Manifold, statistics over the batch axis.
    final_distances = distances_to_manifold[-1]
    median_final_distances = final_distances.median(dim=0).values
    lower_final_distances = final_distances.quantile(0.05, dim=0)
    upper_final_distances = final_distances.quantile(0.95, dim=0)
    ax3.errorbar(
        iterates,
        median_final_distances,
        yerr=[
            median_final_distances - lower_final_distances,
            upper_final_distances - median_final_distances,
        ],
        fmt="o",
        capsize=5,
        label="Median final distance with 5-95% quantiles",
        color="green",
        alpha=0.7,
    )
    ax3.set_xlabel("Conditioned Reward")
    ax3.set_ylabel("Final Distance to Manifold")
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    # Same model-name stem as the JSON written by main(), so the two files pair up.
    model_name = config.names.score_model.split(".")[0]
    plot_filename = (
        f"{config.names.save}{model_name}_conditional_"
        f"steps{config.sampler.num_steps}_"
        f"alpha{config.conditional.alpha}_"
        f"var{config.conditional.variance}.png"
    )
    plt.savefig(Path(config.outputs_dir) / plot_filename)
    plt.close()


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
